{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "5eb0fa8d-51a3-4b32-86c6-0df4baead164",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Retriever Query Engine with Custom Retrievers - Simple Hybrid Search\n",
    "\n",
    "In this tutorial, we show how to define a very simple version of hybrid search.\n",
    "\n",
    "We combine keyword-lookup retrieval with vector retrieval using \"AND\" and \"OR\" conditions."
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "9393b260-57e9-4134-8793-b3423bc891ca",
   "metadata": {},
   "source": [
    "### Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "c2c76626-5f96-4d60-8050-7437e3270363",
   "metadata": {
    "tags": []
   },
   "outputs": [
   ],
   "source": [
    "import logging\n",
    "import sys\n",
    "\n",
    "logging.basicConfig(stream=sys.stdout, level=logging.INFO)\n",
    "logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))\n",
    "\n",
    "from llama_index import (\n",
    "    VectorStoreIndex,\n",
    "    SimpleKeywordTableIndex,\n",
    "    SimpleDirectoryReader,\n",
    "    ServiceContext,\n",
    "    StorageContext,\n",
    ")\n",
    "from IPython.display import Markdown, display"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "ec6a93b3-04e1-4c21-95d8-b93873e68fad",
   "metadata": {},
   "source": [
    "### Load Data\n",
    "\n",
    "We first show how to convert a Document into a set of Nodes and insert them into a DocumentStore."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f1f2276b-03f3-4dd9-9178-ed274d465c17",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# load documents\n",
    "documents = SimpleDirectoryReader(\"../data/paul_graham\").load_data()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "571a5862-0f6f-42b6-9975-2426f1ee8b8f",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# initialize service context (set chunk size)\n",
    "service_context = ServiceContext.from_defaults(chunk_size=1024)\n",
    "node_parser = service_context.node_parser\n",
    "\n",
    "nodes = node_parser.get_nodes_from_documents(documents)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "e61b9f42-a8a0-491f-8342-3cc366689c41",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# initialize storage context (by default it's in-memory)\n",
    "storage_context = StorageContext.from_defaults()\n",
    "storage_context.docstore.add_documents(nodes)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "88fb9077-c4aa-4a52-ac8d-75542ab27501",
   "metadata": {},
   "source": [
    "### Define Vector Index and Keyword Table Index over Same Data\n",
    "\n",
    "We build a vector index and a keyword index over the same DocumentStore."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "b2f3bb70-0c1d-42c0-9b66-c73c0056339f",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total LLM token usage: 0 tokens\n",
      "> [build_index_from_nodes] Total LLM token usage: 0 tokens\n",
      "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total embedding token usage: 17050 tokens\n",
      "> [build_index_from_nodes] Total embedding token usage: 17050 tokens\n",
      "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total LLM token usage: 0 tokens\n",
      "> [build_index_from_nodes] Total LLM token usage: 0 tokens\n",
      "INFO:llama_index.token_counter.token_counter:> [build_index_from_nodes] Total embedding token usage: 0 tokens\n",
      "> [build_index_from_nodes] Total embedding token usage: 0 tokens\n"
     ]
    }
   ],
   "source": [
    "vector_index = VectorStoreIndex(nodes, storage_context=storage_context)\n",
    "keyword_index = SimpleKeywordTableIndex(nodes, storage_context=storage_context)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "6df55ba7-8070-4149-b786-3d3e64b7d64f",
   "metadata": {},
   "source": [
    "### Define Custom Retriever\n",
    "\n",
    "We now define a custom retriever class that implements basic hybrid search by combining keyword lookup with semantic (vector) search.\n",
    "\n",
    "- Setting the mode to \"AND\" takes the intersection of the two retrieved sets, so only nodes returned by both retrievers are kept.\n",
    "- Setting the mode to \"OR\" takes the union, so nodes returned by either retriever are kept."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "b45b406e-c42c-48d2-bd62-228cf62b3b43",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# import QueryBundle\n",
    "from llama_index import QueryBundle\n",
    "\n",
    "# import NodeWithScore\n",
    "from llama_index.schema import NodeWithScore\n",
    "\n",
    "# Retrievers\n",
    "from llama_index.retrievers import (\n",
    "    BaseRetriever,\n",
    "    VectorIndexRetriever,\n",
    "    KeywordTableSimpleRetriever,\n",
    ")\n",
    "\n",
    "from typing import List"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "2b17f16c-6059-4855-9417-a00bb41d4c12",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "class CustomRetriever(BaseRetriever):\n",
    "    \"\"\"Custom retriever that combines semantic (vector) search with keyword search.\"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        vector_retriever: VectorIndexRetriever,\n",
    "        keyword_retriever: KeywordTableSimpleRetriever,\n",
    "        mode: str = \"AND\",\n",
    "    ) -> None:\n",
    "        \"\"\"Init params.\"\"\"\n",
    "\n",
    "        self._vector_retriever = vector_retriever\n",
    "        self._keyword_retriever = keyword_retriever\n",
    "        if mode not in (\"AND\", \"OR\"):\n",
    "            raise ValueError(\"Invalid mode: expected 'AND' or 'OR'.\")\n",
    "        self._mode = mode\n",
    "\n",
    "    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:\n",
    "        \"\"\"Retrieve nodes given query.\"\"\"\n",
    "\n",
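    "        # run both retrievers on the same query\n",
    "        vector_nodes = self._vector_retriever.retrieve(query_bundle)\n",
    "        keyword_nodes = self._keyword_retriever.retrieve(query_bundle)\n",
    "\n",
    "        vector_ids = {n.node.node_id for n in vector_nodes}\n",
    "        keyword_ids = {n.node.node_id for n in keyword_nodes}\n",
    "\n",
    "        # map node id -> NodeWithScore; for ids found by both retrievers,\n",
    "        # the keyword retriever's entry overwrites the vector retriever's\n",
    "        combined_dict = {n.node.node_id: n for n in vector_nodes}\n",
    "        combined_dict.update({n.node.node_id: n for n in keyword_nodes})\n",
    "\n",
    "        # \"AND\" keeps ids found by both retrievers, \"OR\" keeps ids found by either\n",
    "        if self._mode == \"AND\":\n",
    "            retrieve_ids = vector_ids.intersection(keyword_ids)\n",
    "        else:\n",
    "            retrieve_ids = vector_ids.union(keyword_ids)\n",
    "\n",
    "        # note: set iteration order is arbitrary, so the result is not ranked\n",
    "        retrieve_nodes = [combined_dict[rid] for rid in retrieve_ids]\n",
    "        return retrieve_nodes"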
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "id": "d111dc0c-a5f7-44d1-b9b8-fbdb36683040",
   "metadata": {},
   "source": [
    "### Plug Retriever into Query Engine\n",
    "\n",
    "We plug the custom retriever into a query engine and run some queries."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "2d328d7e-4537-4007-852d-e575e4d11110",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "from llama_index import get_response_synthesizer\n",
    "from llama_index.query_engine import RetrieverQueryEngine\n",
    "\n",
    "# define custom retriever\n",
    "vector_retriever = VectorIndexRetriever(index=vector_index, similarity_top_k=2)\n",
    "keyword_retriever = KeywordTableSimpleRetriever(index=keyword_index)\n",
    "custom_retriever = CustomRetriever(vector_retriever, keyword_retriever)\n",
    "\n",
    "# define response synthesizer\n",
    "response_synthesizer = get_response_synthesizer()\n",
    "\n",
    "# assemble query engine\n",
    "custom_query_engine = RetrieverQueryEngine(\n",
    "    retriever=custom_retriever,\n",
    "    response_synthesizer=response_synthesizer,\n",
    ")\n",
    "\n",
    "# vector query engine\n",
    "vector_query_engine = RetrieverQueryEngine(\n",
    "    retriever=vector_retriever,\n",
    "    response_synthesizer=response_synthesizer,\n",
    ")\n",
    "# keyword query engine\n",
    "keyword_query_engine = RetrieverQueryEngine(\n",
    "    retriever=keyword_retriever,\n",
    "    response_synthesizer=response_synthesizer,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "6cbc3209-cb4e-4a0a-ada8-1d8eff9e657e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:llama_index.token_counter.token_counter:> [retrieve] Total LLM token usage: 0 tokens\n",
      "> [retrieve] Total LLM token usage: 0 tokens\n",
      "INFO:llama_index.token_counter.token_counter:> [retrieve] Total embedding token usage: 12 tokens\n",
      "> [retrieve] Total embedding token usage: 12 tokens\n",
      "INFO:llama_index.indices.keyword_table.retrievers:> Starting query: What did the author do during his time at YC?\n",
      "> Starting query: What did the author do during his time at YC?\n",
      "INFO:llama_index.indices.keyword_table.retrievers:query keywords: ['time', 'yc', 'author']\n",
      "query keywords: ['time', 'yc', 'author']\n",
      "INFO:llama_index.indices.keyword_table.retrievers:> Extracted keywords: ['time', 'yc']\n",
      "> Extracted keywords: ['time', 'yc']\n",
      "INFO:llama_index.token_counter.token_counter:> [get_response] Total LLM token usage: 1919 tokens\n",
      "> [get_response] Total LLM token usage: 1919 tokens\n",
      "INFO:llama_index.token_counter.token_counter:> [get_response] Total embedding token usage: 0 tokens\n",
      "> [get_response] Total embedding token usage: 0 tokens\n"
     ]
    }
   ],
   "source": [
    "response = custom_query_engine.query(\"What did the author do during his time at YC?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "d3f74212-42b1-40f0-b9aa-98098ad75a22",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "The author worked on YC, wrote essays, worked on a new version of Arc, wrote Hacker News in Arc, wrote YC's internal software in Arc, and dealt with disputes between cofounders, figuring out when people were lying to them, and fighting with people who maltreated the startups.\n"
     ]
    }
   ],
   "source": [
    "print(response)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "48eb7360-644d-408a-8c43-8ddebbdebb64",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:llama_index.token_counter.token_counter:> [retrieve] Total LLM token usage: 0 tokens\n",
      "> [retrieve] Total LLM token usage: 0 tokens\n",
      "INFO:llama_index.token_counter.token_counter:> [retrieve] Total embedding token usage: 11 tokens\n",
      "> [retrieve] Total embedding token usage: 11 tokens\n",
      "INFO:llama_index.indices.keyword_table.retrievers:> Starting query: What did the author do during his time at Yale?\n",
      "> Starting query: What did the author do during his time at Yale?\n",
      "INFO:llama_index.indices.keyword_table.retrievers:query keywords: ['yale', 'time', 'author']\n",
      "query keywords: ['yale', 'time', 'author']\n",
      "INFO:llama_index.indices.keyword_table.retrievers:> Extracted keywords: ['time']\n",
      "> Extracted keywords: ['time']\n",
      "INFO:llama_index.token_counter.token_counter:> [get_response] Total LLM token usage: 0 tokens\n",
      "> [get_response] Total LLM token usage: 0 tokens\n",
      "INFO:llama_index.token_counter.token_counter:> [get_response] Total embedding token usage: 0 tokens\n",
      "> [get_response] Total embedding token usage: 0 tokens\n"
     ]
    }
   ],
   "source": [
    "# in \"AND\" mode, hybrid search can avoid retrieving irrelevant nodes:\n",
    "# Yale is never mentioned in the essay\n",
    "response = custom_query_engine.query(\"What did the author do during his time at Yale?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "95da0140-19f7-4533-befb-6153c9bda550",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "None\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(str(response))\n",
    "len(response.source_nodes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "6507bf10-2c0c-484b-bea9-fbd60354fcad",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:llama_index.token_counter.token_counter:> [retrieve] Total LLM token usage: 0 tokens\n",
      "> [retrieve] Total LLM token usage: 0 tokens\n",
      "INFO:llama_index.token_counter.token_counter:> [retrieve] Total embedding token usage: 11 tokens\n",
      "> [retrieve] Total embedding token usage: 11 tokens\n",
      "INFO:llama_index.token_counter.token_counter:> [get_response] Total LLM token usage: 1871 tokens\n",
      "> [get_response] Total LLM token usage: 1871 tokens\n",
      "INFO:llama_index.token_counter.token_counter:> [get_response] Total embedding token usage: 0 tokens\n",
      "> [get_response] Total embedding token usage: 0 tokens\n"
     ]
    }
   ],
   "source": [
    "# in contrast, vector search alone still retrieves its top-k nodes and returns an answer\n",
    "response = vector_query_engine.query(\"What did the author do during his time at Yale?\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "c1696bac-7978-4e2a-ad1b-84e5b2d5a1d4",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "The author did not attend Yale. The context information provided is about the author's work before and after college.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print(str(response))\n",
    "len(response.source_nodes)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llama-index",
   "language": "python",
   "name": "llama-index"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
